{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 1.0\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n"
     ]
    }
   ],
   "source": [
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default CDSSM preprocessor, configured with an n-gram size of 3.\n",
    "preprocessor = mz.models.CDSSM.get_default_preprocessor(ngram_size=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8404.73it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 4333.11it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 843570.94it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 128272.36it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 139975.31it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 640485.64it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 692587.11it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 418401/418401 [00:00<00:00, 2541596.17it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 38102.56it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:01<00:00, 15164.94it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 514808.52it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 342142.01it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 2082963/2082963 [00:00<00:00, 2731408.69it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 8582.66it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:04<00:00, 4417.94it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 126014.39it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 212463.79it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 107616.56it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 613376.78it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 815084.44it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 8473.20it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 4559.90it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 136588.36it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 103103.99it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 129475.33it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 171482.94it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 587076.19it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 8762.86it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 4477.27it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 124450.43it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 134186.02it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 117824.72it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 314075.84it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 731879.16it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit the preprocessor on the training split only ...\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "\n",
    "# ... then apply the fitted preprocessor to the dev and test splits,\n",
    "# in that order.\n",
    "valid_pack_processed, test_pack_processed = [\n",
    "    preprocessor.transform(raw_pack)\n",
    "    for raw_pack in (dev_pack_raw, test_pack_raw)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ngram_process_unit': <matchzoo.preprocessors.units.ngram_letter.NgramLetter at 0x7fed5ef08630>,\n",
       " 'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x7fed5ef08588>,\n",
       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7fed5ef57160>,\n",
       " 'vocab_size': 30058,\n",
       " 'embedding_input_dim': 30058,\n",
       " 'ngram_vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x7fec8ce58080>,\n",
       " 'ngram_vocab_size': 9654}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# N-gram callback built from the fitted preprocessor (mode='sum').\n",
    "# The same callback instance is attached to both datasets below.\n",
    "triletter_callback = mz.dataloader.callbacks.Ngram(preprocessor, mode='sum')\n",
    "\n",
    "# Training data: 'pair' mode with num_dup=2 and num_neg=1.\n",
    "trainset = mz.dataloader.Dataset(\n",
    "    callbacks=[triletter_callback],\n",
    "    data_pack=train_pack_processed,\n",
    "    mode='pair',\n",
    "    num_neg=1,\n",
    "    num_dup=2\n",
    ")\n",
    "\n",
    "# Test data: Dataset defaults for mode / num_dup / num_neg.\n",
    "testset = mz.dataloader.Dataset(data_pack=test_pack_processed, callbacks=[triletter_callback])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CDSSM default padding callback with n-gram padding enabled; the fixed\n",
    "# n-gram length is the n-gram vocabulary size recorded by the preprocessor.\n",
    "padding_callback = mz.models.CDSSM.get_default_padding_callback(\n",
    "    fixed_ngram_length=preprocessor.context['ngram_vocab_size'],\n",
    "    with_ngram=True\n",
    ")\n",
    "\n",
    "# Options that are identical for the train and test loaders.\n",
    "shared_loader_kwargs = dict(batch_size=20, sort=False, callback=padding_callback)\n",
    "\n",
    "# Only the training loader asks for resampling.\n",
    "trainloader = mz.dataloader.DataLoader(\n",
    "    dataset=trainset, stage='train', resample=True, **shared_loader_kwargs\n",
    ")\n",
    "testloader = mz.dataloader.DataLoader(\n",
    "    dataset=testset, stage='dev', **shared_loader_kwargs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CDSSM(\n",
      "  (net_left): Sequential(\n",
      "    (0): ConstantPad1d(padding=(0, 2), value=0)\n",
      "    (1): Conv1d(9654, 64, kernel_size=(3,), stride=(1,))\n",
      "    (2): Tanh()\n",
      "    (3): Dropout(p=0.8, inplace=False)\n",
      "    (4): AdaptiveMaxPool1d(output_size=1)\n",
      "    (5): Squeeze()\n",
      "    (6): Sequential(\n",
      "      (0): Sequential(\n",
      "        (0): Linear(in_features=64, out_features=64, bias=True)\n",
      "        (1): Tanh()\n",
      "      )\n",
      "      (1): Sequential(\n",
      "        (0): Linear(in_features=64, out_features=64, bias=True)\n",
      "        (1): Tanh()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (net_right): Sequential(\n",
      "    (0): ConstantPad1d(padding=(0, 2), value=0)\n",
      "    (1): Conv1d(9654, 64, kernel_size=(3,), stride=(1,))\n",
      "    (2): Tanh()\n",
      "    (3): Dropout(p=0.8, inplace=False)\n",
      "    (4): AdaptiveMaxPool1d(output_size=1)\n",
      "    (5): Squeeze()\n",
      "    (6): Sequential(\n",
      "      (0): Sequential(\n",
      "        (0): Linear(in_features=64, out_features=64, bias=True)\n",
      "        (1): Tanh()\n",
      "      )\n",
      "      (1): Sequential(\n",
      "        (0): Linear(in_features=64, out_features=64, bias=True)\n",
      "        (1): Tanh()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (out): Linear(in_features=1, out_features=1, bias=True)\n",
      ")\n",
      "Trainable params:  3723906\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.CDSSM()\n",
    "\n",
    "# Hyper-parameters, applied to the model one by one in the order listed.\n",
    "cdssm_settings = {\n",
    "    'task': ranking_task,\n",
    "    # n-gram vocabulary size recorded by the fitted preprocessor\n",
    "    'vocab_size': preprocessor.context['ngram_vocab_size'],\n",
    "    'filters': 64,\n",
    "    'kernel_size': 3,\n",
    "    'conv_activation_func': 'tanh',\n",
    "    'mlp_num_layers': 1,\n",
    "    'mlp_num_units': 64,\n",
    "    'mlp_num_fan_out': 64,\n",
    "    'mlp_activation_func': 'tanh',\n",
    "    'dropout_rate': 0.8,\n",
    "}\n",
    "for param_name, param_value in cdssm_settings.items():\n",
    "    model.params[param_name] = param_value\n",
    "\n",
    "model.build()\n",
    "\n",
    "# Show the built architecture, then count the parameters that require gradients.\n",
    "print(model)\n",
    "trainable_count = sum(\n",
    "    weight.numel() for weight in model.parameters() if weight.requires_grad\n",
    ")\n",
    "print('Trainable params: ', trainable_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adadelta(model.parameters())\n",
    "\n",
    "# Validate on the held-out dev split instead of the test split, so the\n",
    "# per-epoch validation metrics are computed on dev data and the test split\n",
    "# is not used during training. The dataset and loader mirror the settings\n",
    "# used for the test dataset and test loader.\n",
    "validset = mz.dataloader.Dataset(\n",
    "    data_pack=valid_pack_processed,\n",
    "    callbacks=[triletter_callback]\n",
    ")\n",
    "validloader = mz.dataloader.DataLoader(\n",
    "    dataset=validset,\n",
    "    batch_size=20,\n",
    "    stage='dev',\n",
    "    sort=False,\n",
    "    callback=padding_callback\n",
    ")\n",
    "\n",
    "# `testloader` stays defined for a final, one-off evaluation after training.\n",
    "trainer = mz.trainers.Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    trainloader=trainloader,\n",
    "    validloader=validloader,\n",
    "    validate_interval=None,\n",
    "    epochs=10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d516ecd721d5498d9b64dac2edad9ec3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-102 Loss-0.820]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5302 - normalized_discounted_cumulative_gain@5(0.0): 0.5922 - mean_average_precision(0.0): 0.5508\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04cc2cc860024c2aa1eb0c1d6d0d7522",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-204 Loss-0.636]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5696 - normalized_discounted_cumulative_gain@5(0.0): 0.628 - mean_average_precision(0.0): 0.5867\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad66e908e6f34685a1690b46f31d4d70",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-306 Loss-0.500]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5776 - normalized_discounted_cumulative_gain@5(0.0): 0.6344 - mean_average_precision(0.0): 0.5922\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0cc82172abf4d20945ca220e9feaf36",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-408 Loss-0.425]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5808 - normalized_discounted_cumulative_gain@5(0.0): 0.6469 - mean_average_precision(0.0): 0.6014\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "217246bea66e438eb3b4d2fa54fa7491",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-510 Loss-0.333]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5925 - normalized_discounted_cumulative_gain@5(0.0): 0.6515 - mean_average_precision(0.0): 0.6139\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02d42d72fe894ae4908765a298d8fce3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-612 Loss-0.307]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5773 - normalized_discounted_cumulative_gain@5(0.0): 0.6491 - mean_average_precision(0.0): 0.6076\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c6c84eecf614badaee330d407043ea6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-714 Loss-0.289]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.581 - normalized_discounted_cumulative_gain@5(0.0): 0.6422 - mean_average_precision(0.0): 0.6075\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3d54d0052a7545a9a143a6297134e08e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-816 Loss-0.238]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5709 - normalized_discounted_cumulative_gain@5(0.0): 0.6393 - mean_average_precision(0.0): 0.5938\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "afdc8ede033546f381df505e019bde0b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-918 Loss-0.220]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5683 - normalized_discounted_cumulative_gain@5(0.0): 0.6377 - mean_average_precision(0.0): 0.5975\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73af62944c4e48c49bbac68cd6e1f8be",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-1020 Loss-0.212]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5781 - normalized_discounted_cumulative_gain@5(0.0): 0.6427 - mean_average_precision(0.0): 0.6004\n",
      "\n",
      "Cost time: 3756.318562269211s\n"
     ]
    }
   ],
   "source": [
    "trainer.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
